Split
=================

Splits a tensor along the specified axis (``axis``) into multiple sub-tensors. The length of each sub-tensor along that axis is given by the ``split_sizes`` array; all other dimensions are unchanged.

.. math::

    \text{input.shape} = [d_0, d_1, \dots, d_{axis}, \dots, d_{n-1}]

.. math::

    \text{for the } j\text{-th output: } \text{output}_j\text{.shape} = [d_0, \dots, d_{axis-1}, \mathrm{split\_sizes}[j], d_{axis+1}, \dots, d_{n-1}], \quad j = 0, 1, \dots, \mathrm{num\_split} - 1

where :math:`n` is ``input_ndim``.

Inputs:

- **input** - Start address of the input data.
- **axis** - The dimension along which to split.
- **input_shape** - Address of the array holding the input tensor's shape.
- **input_ndim** - Number of dimensions of the input tensor.
- **num_split** - Number of sub-tensors to produce.
- **split_sizes** - Array holding the length of each sub-tensor along the split axis.
- **core_mask (int, optional)** - Core mask (shared-storage version only).

Outputs:

- **outputs** - Address of an array of pointers; each element points to the storage of one sub-tensor.

Supported platforms:

``FT78NE`` ``MT7004``

.. note::

    - FT78NE supports int8, int16, int32, fp32, fp64, cplx64, cplx128.
    - MT7004 supports fp16, fp32, int16, int32, cplx64.
    - The elements of ``split_sizes`` must sum to the length of the input tensor along ``axis``.
    - For complex types (cplx64 / cplx128) the split logic is the same as for real types, but address offsets are counted in complex pairs: each element occupies two consecutive ``float`` (cplx64) or ``double`` (cplx128) values (real and imaginary parts), which is why ``c64_split_*`` and ``c128_split_*`` take ``float*`` and ``double*`` pointers.

**Shared-storage version:**

.. c:function:: void i8_split_s(int8_t* input, int8_t* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void i16_split_s(int16_t* input, int16_t* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void i32_split_s(int32_t* input, int32_t* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void hp_split_s(half* input, half* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void fp_split_s(float* input, float* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void dp_split_s(double* input, double* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void c64_split_s(float* input, float* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

.. c:function:: void c128_split_s(double* input, double* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes, int core_mask)

**C call example:**

.. code-block:: c
    :linenos:
    :emphasize-lines: 16

    // FT78NE example (shared storage)
    #include "78NE/utils.h"

    int main(int argc, char* argv[])
    {
        float *input = (float *)0xA0000000;
        float *out0 = (float *)0xB0000000;   // receives shape {2, 6, 4}
        float *out1 = (float *)0xB1000000;   // receives shape {2, 4, 4}
        float *outputs[] = { out0, out1 };
        int input_shape[] = { 2, 10, 4 };
        int split_sizes[] = { 6, 4 };        // 6 + 4 == input_shape[axis]
        int axis = 1;
        int input_ndim = 3;
        int num_split = 2;
        int core_mask = 0b1011;
        fp_split_s(input, outputs, axis, input_shape, input_ndim, num_split, split_sizes, core_mask);
        return 0;
    }

**Private-storage version:**

.. c:function:: void i8_split_p(int8_t* input, int8_t* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void i16_split_p(int16_t* input, int16_t* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void i32_split_p(int32_t* input, int32_t* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void hp_split_p(half* input, half* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void fp_split_p(float* input, float* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void dp_split_p(double* input, double* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void c64_split_p(float* input, float* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

.. c:function:: void c128_split_p(double* input, double* outputs[], int axis, int* input_shape, int input_ndim, int num_split, int* split_sizes)

**C call example:**

.. code-block:: c
    :linenos:
    :emphasize-lines: 15

    // MT7004 example (private storage)
    // Include the MT7004 library header that declares fp_split_p here.

    int main(int argc, char* argv[])
    {
        float *input = (float *)0x10000000;  // private-storage address
        float *out0 = (float *)0x10010000;
        float *out1 = (float *)0x10020000;
        float *outputs[] = { out0, out1 };
        int input_shape[] = { 20, 10 };
        int split_sizes[] = { 10, 10 };      // 10 + 10 == input_shape[axis]
        int axis = 0;
        int input_ndim = 2;
        int num_split = 2;
        fp_split_p(input, outputs, axis, input_shape, input_ndim, num_split, split_sizes);
        return 0;
    }
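
For checking results on a host, the split semantics can be written as a plain-C reference loop. This is a hypothetical sketch, not the FT78NE/MT7004 implementation: it assumes a dense row-major (C-order) layout and, for cplx64, that each element is stored as two consecutive ``float`` values. ``ref_split_f32`` and its ``scalars_per_elem`` parameter are illustrative names only.

.. code-block:: c

    #include <stddef.h>
    #include <string.h>

    /* Hypothetical host-side reference for Split (not a library function).
     * Assumes a dense row-major layout. scalars_per_elem is 1 for real float
     * data and 2 for cplx64 data stored as two consecutive floats per element.
     * Returns 0 on success, -1 on invalid arguments. */
    static int ref_split_f32(const float *input, float *outputs[], int axis,
                             const int *input_shape, int input_ndim,
                             int num_split, const int *split_sizes,
                             int scalars_per_elem)
    {
        if (axis < 0 || axis >= input_ndim) return -1;

        /* split_sizes must add up to the length of the split axis. */
        long total = 0;
        for (int j = 0; j < num_split; ++j) total += split_sizes[j];
        if (total != input_shape[axis]) return -1;

        /* outer: product of dims before axis; inner: product after axis,
         * measured in scalar (float) units. */
        size_t outer = 1, inner = (size_t)scalars_per_elem;
        for (int d = 0; d < axis; ++d) outer *= (size_t)input_shape[d];
        for (int d = axis + 1; d < input_ndim; ++d) inner *= (size_t)input_shape[d];

        const float *src = input;
        for (size_t o = 0; o < outer; ++o) {
            for (int j = 0; j < num_split; ++j) {
                /* One contiguous slab of split_sizes[j] * inner scalars. */
                size_t chunk = (size_t)split_sizes[j] * inner;
                memcpy(outputs[j] + o * chunk, src, chunk * sizeof(float));
                src += chunk;
            }
        }
        return 0;
    }

Everything before ``axis`` collapses into ``outer`` and everything after it into ``inner``, so each output receives ``outer`` contiguous slabs of ``split_sizes[j] * inner`` values. When ``axis`` is 0, ``outer`` is 1 and, under the row-major assumption, each output is a single contiguous block of the input, as in the MT7004 example.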
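
When output buffers are placed at fixed addresses, it helps to know how many scalar slots each output occupies. The hypothetical helper below applies the output-shape formula and doubles the count for complex types, assuming each cplx64/cplx128 element is stored as two consecutive ``float``/``double`` values. ``split_output_scalars`` is an illustrative name, not a library function.

.. code-block:: c

    #include <stddef.h>

    /* Hypothetical helper: number of scalar slots (float or double units)
     * that output j occupies. Every dimension keeps its input length except
     * axis, which takes split_sizes[j]. scalars_per_elem is 1 for real types
     * and 2 for complex types (assumed two scalars per element). */
    static size_t split_output_scalars(const int *input_shape, int input_ndim,
                                       int axis, const int *split_sizes, int j,
                                       int scalars_per_elem)
    {
        size_t n = (size_t)scalars_per_elem;
        for (int d = 0; d < input_ndim; ++d)
            n *= (size_t)(d == axis ? split_sizes[j] : input_shape[d]);
        return n;
    }

For the FT78NE example (shape ``{2, 10, 4}``, ``axis = 1``, ``split_sizes = {6, 4}``) this gives 48 and 32 ``float`` values, that is 192 and 128 bytes for ``out0`` and ``out1``.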